from argparse import ArgumentParser
import mmap
from struct import unpack

# Little-endian u32 expected at offset 0 of every partition header.
MAGIC = 0x58881688
# Little-endian u32 at offset 48 signalling that the extended header fields
# (header size, end-of-list flag, alignment, high size word) are present.
EXT_MAGIC = 0x58891689


def _parse_header(buffer):
    """Parse the partition header at the start of *buffer*.

    Shared by ``get_next_partition`` and ``extract_partition2`` so both agree
    on field positions and defaults.

    :param buffer: bytes-like object (or mmap) starting at a partition header.
    :returns: ``(header_size, data_size, alignment, image_list_end)``.
        Without an extended header the defaults are a 512-byte header,
        16-byte alignment and ``image_list_end == 0``.
    :raises RuntimeError: if the magic does not match or the header is
        shorter than the fields that must be read.
    """
    if len(buffer) < 4 or unpack('<I', buffer[0:4])[0] != MAGIC:
        raise RuntimeError('invalid magic')
    # The extended-magic word lives at 48..52, so at least 52 bytes are needed.
    if len(buffer) < 52:
        raise RuntimeError('truncated header')

    header_size = 512
    data_size = unpack('<I', buffer[4:8])[0]
    alignment = 16
    image_list_end = 0

    if unpack('<I', buffer[48:52])[0] == EXT_MAGIC:
        if len(buffer) < 76:
            raise RuntimeError('truncated header')
        header_size = unpack('<I', buffer[52:56])[0]
        image_list_end = unpack('<I', buffer[64:68])[0]
        alignment = unpack('<I', buffer[68:72])[0]
        # The word at 72..76 holds the high 32 bits of the payload size.
        data_size |= unpack('<I', buffer[72:76])[0] << 32

    # Guard a zero alignment field so the round-up modulo cannot divide by
    # zero; it is treated as "no alignment".
    return header_size, data_size, max(alignment, 1), image_list_end


def get_next_partition(buffer, offset):
    """Return the absolute offset of the partition after the one in *buffer*.

    :param buffer: bytes starting at the current partition's header.
    :param offset: absolute offset of that header within the image; the
        result is aligned relative to the image start, not to *buffer*.
    :returns: ``offset + header_size + data_size`` rounded up to the
        header's alignment, or ``None`` when the extended header's
        end-of-list field is non-zero.
    :raises RuntimeError: if the magic is wrong or the header is truncated.
    """
    header_size, data_size, alignment, image_list_end = _parse_header(buffer)
    if image_list_end:
        return None
    end = offset + header_size + data_size
    # (-end) % alignment is the padding needed to reach the next multiple.
    return end + (-end) % alignment


def extract_partition2(buffer, output, with_header: bool):
    """Write the partition whose header starts *buffer* to the file *output*.

    With *with_header* false only the payload (the ``data_size`` bytes that
    follow the header) is written. With *with_header* true the header,
    the payload and the padding up to the partition's alignment are written,
    i.e. the same span ``get_next_partition`` steps over (truncated if
    *buffer* ends sooner).

    :raises RuntimeError: if the magic is wrong or the header is truncated.
    """
    header_size, data_size, alignment, _ = _parse_header(buffer)
    if with_header:
        end = header_size + data_size
        # Round up without adding a full extra block when already aligned.
        end += (-end) % alignment
        payload = buffer[:end]
    else:
        payload = buffer[header_size:header_size + data_size]

    with open(output, 'wb') as f:
        f.write(payload)


def extract_partition(buffer, index, output, with_header: bool):
    """Write partition number *index* of *buffer* to the file *output*.

    Partitions are visited in order from offset 0, hopping from one header
    to the next with ``get_next_partition``; the selected partition is then
    written by ``extract_partition2``.

    :param buffer: the whole image (bytes-like object or mmap).
    :param index: zero-based partition index.
    :param output: destination file path.
    :param with_header: forwarded to ``extract_partition2``.
    :raises RuntimeError: if *index* is negative or the image runs out of
        partitions before reaching it.
    """
    if index < 0:
        raise RuntimeError('no such partition: {}'.format(index))

    offset = 0
    for _ in range(index):
        # get_next_partition only reads fixed header fields near the start
        # of its buffer, so hand it a header-sized slice rather than copying
        # the whole remainder of the image on every hop.
        next_offset = get_next_partition(buffer[offset:offset + 512], offset)
        if next_offset is None or next_offset >= len(buffer):
            raise RuntimeError('no such partition: {}'.format(index))
        offset = next_offset

    extract_partition2(buffer[offset:], output, with_header)


def main():
    """Command-line entry point: extract one partition from an image file.

    Usage: ``[-H] input index output``. The image is memory-mapped read-only
    and ``-H/--with-header`` is forwarded to ``extract_partition``.
    """
    parser = ArgumentParser(description='Extract a partition from an image.')
    parser.add_argument('-H', '--with-header', action='store_true',
                        help='include the partition header in the output')
    parser.add_argument('input', help='image file to read')
    parser.add_argument('index', type=int, help='zero-based partition index')
    parser.add_argument('output', help='file to write the partition to')
    args = parser.parse_args()

    # Context managers close both the mapping and the file descriptor even
    # when extraction raises.
    with open(args.input, 'rb') as image_file:
        with mmap.mmap(image_file.fileno(), 0,
                       access=mmap.ACCESS_READ) as image:
            extract_partition(image, args.index, args.output,
                              args.with_header)


if __name__ == '__main__':
    main()
